import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
import pandas as pd
import numpy as np

# 设置 PyTorch 随机数种子
torch.manual_seed(1234)

# 自定义 MNIST 数据集类
class MNISTDataset(Dataset):
    def __init__(self, path):
        # 加载数据集
        df = pd.read_csv(path)
        self.labels = np.asarray(df.iloc[:, 0])
        self.images = np.asarray(df.iloc[:, 1:]).reshape((-1, 28, 28))
        self.images = self.images / 255.0  # 将像素值缩放到 [0, 1]
        self.images = self.images.astype(np.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # 将图像和标签转换为 PyTorch 张量
        image = torch.from_numpy(self.images[index])
        label = torch.tensor(self.labels[index], dtype=torch.long)
        return image, label

# 加载自定义 MNIST 数据集
train_dataset = MNISTDataset("./data/train.csv")
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 定义神经网络模型
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
for epoch in range(5):  # 进行 5 次迭代训练
    for i, (images, labels) in enumerate(train_loader):
        # 将输入数据展开为一维向量
        images = images.view(-1, 28 * 28)
        # 清空梯度
        optimizer.zero_grad()
        # 计算模型输出
        outputs = model(images)
        # 计算损失函数
        loss = criterion(outputs, labels)
        # 反向传播
        loss.backward()
        # 更新模型参数
        optimizer.step()
        # 打印训练日志
        if (i + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                epoch+1, 5, i+1, len(train_loader), loss.item()))

# 保存模型
torch.save(model.state_dict(), "model.ckpt")